import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
import sys

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited', wvarb = None):
    """Linear disparity estimator: regress ``outcome`` on ``pbvarb``.

    Fits ``outcome ~ pbvarb`` by OLS, or by WLS with weights
    ``dataset[wvarb]`` when ``wvarb`` is given, using HC1
    heteroskedasticity-robust standard errors.

    Returns:
        tuple: (slope on ``pbvarb``, its HC1 standard error, fitted results).
    """
    formula = outcome + ' ~ ' + pbvarb
    if wvarb is None:
        spec = smf.ols(formula, dataset)
    else:
        spec = smf.wls(formula, dataset, weights=dataset[wvarb])
    fitted = spec.fit(cov_type='HC1')
    return fitted.params[pbvarb], fitted.bse[pbvarb], fitted



## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb='pBlack', outcome='audited', wvarb=None):
    """Probabilistic disparity estimator.

    Computes the probability-weighted mean of ``outcome`` using
    ``pbvarb`` as weight, minus the same mean using ``1 - pbvarb`` as
    weight. When ``wvarb`` is given, both weights are additionally
    multiplied by ``dataset[wvarb]``.
    """
    share_black = dataset[pbvarb]
    share_other = 1 - share_black
    y = dataset[outcome]
    # A scalar weight of 1 leaves the unweighted products unchanged.
    w = 1 if wvarb is None else dataset[wvarb]
    mean_black = (share_black * y * w).sum() / (share_black * w).sum()
    mean_other = (share_other * y * w).sum() / (share_other * w).sum()
    return mean_black - mean_other



def getWVar(values, weights):
    """Return the weighted population (ddof=0) variance of ``values``."""
    centre = np.average(values, weights=weights)
    return np.average(np.square(values - centre), weights=weights)

def getSEMultiplier(dataset, pbvarb='pBlack', wvarb=None):
    """Factor that converts the linear estimator's SE into the probabilistic one.

    Returns ``sqrt(Var(p) / (E[p] * (1 - E[p])))`` with ``p = dataset[pbvarb]``.

    Unweighted (``wvarb is None``): pandas sample variance (ddof=1) and the
    plain mean.
    Weighted: weighted population variance and the *weighted* mean of ``p``.
    The result is therefore unchanged when all weights are rescaled by a
    constant.
    """
    p = dataset[pbvarb]
    if wvarb is None:
        return np.sqrt(p.var() / (p.mean() * (1 - p.mean())))
    w = dataset[wvarb]
    # Weighted mean of p. The plain mean of p*w would scale with mean(w).
    p_bar = np.average(p, weights=w)
    p_var = np.average((p - p_bar) ** 2, weights=w)
    return np.sqrt(p_var / (p_bar * (1 - p_bar)))

def getSEs(dataset, pbvarb='pBlack', outcome='audited', wvarb=None):
    """Return ``(probabilistic-estimator SE, linear-estimator SE)``.

    The linear SE is the HC1 standard error from ``regEstimate``. The
    probabilistic SE is that value scaled by ``getSEMultiplier``.
    """
    multiplier = getSEMultiplier(dataset, pbvarb, wvarb=wvarb)
    _, linear_se, _ = regEstimate(dataset, pbvarb, outcome, wvarb=wvarb)
    return linear_se * multiplier, linear_se

def w_avg(df, values, weights):
    """Weighted mean of column ``values`` using column ``weights`` of ``df``.

    Computed as sum(values * weights) / sum(weights), using pandas
    NaN-skipping sums.
    """
    wts = df[weights]
    return df[values].mul(wts).sum() / wts.sum()

def get_cond_cov(df, covarb, conditionvarb, outcomevar ,wgtvar):
    """Weighted expectation of a conditional covariance, E[Cov(Y, X | C)].

    For each distinct value c of ``conditionvarb`` (C), computes the weighted
    population (ddof=0) covariance between ``covarb`` (X) and ``outcomevar``
    (Y) within the rows where C == c. The group covariances are then averaged
    using each group's share of the total weight. The within-group covariance
    equals the WLS slope of Y on X times the weighted variance of X. It uses
    the same E_w[XY] - E_w[X]E_w[Y] convention as the other covariance terms
    in this script.

    Args:
        df: DataFrame containing the four columns below.
        covarb: Column used as X.
        conditionvarb: Column whose distinct values define the groups C.
        outcomevar: Column used as Y.
        wgtvar: Column of row weights (assumed non-negative; missing weights
            count as 0).

    Returns:
        sum_c P_w(C = c) * Cov_w(X, Y | C = c).
        - Rows missing X, C or Y are dropped first.
        - Groups with zero total weight, or with a constant X or Y, add 0.
        - An empty input returns 0.0.
    """
    df = df.dropna(subset=[covarb, conditionvarb, outcomevar])
    totalwt = df[wgtvar].sum()
    exp_cond_cov = 0.0
    for _, dat in df.groupby(conditionvarb, sort=False):
        w = dat[wgtvar].fillna(0).to_numpy(dtype=float)
        grp_wt = w.sum()
        if grp_wt == 0:
            # No probability mass in this group: it cannot contribute.
            continue
        x = dat[covarb].to_numpy(dtype=float)
        y = dat[outcomevar].to_numpy(dtype=float)
        x_dev = x - (w * x).sum() / grp_wt
        y_dev = y - (w * y).sum() / grp_wt
        cond_cov = (w * x_dev * y_dev).sum() / grp_wt
        exp_cond_cov += (grp_wt / totalwt) * cond_cov
    return exp_cond_cov

### Operational audits and predicted race probabilities
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final.csv")
dataBISG = dataBISG[dataBISG.predicted_prob_black.notna()]

# NC taxpayers only. Keep the id and the BISG probability under a temporary
# name (pprob_black) so it does not collide with a same-named column in the
# NC analysis file. Selecting with .loc and renaming builds a new frame, which
# avoids a chained assignment on a filtered slice (SettingWithCopyWarning).
nc_BISG = (
    dataBISG.loc[dataBISG['state'] == "NC", ['taxpayer_id', 'predicted_prob_black']]
    .rename(columns={'predicted_prob_black': 'pprob_black'})
)

##NC data
ncdata = pd.read_csv("/REDACTED/NC_Analysis_Dataset_2014_tplevel.csv")

nc_weight = pd.read_stata('/REDACTED/BIFSG/nc_rewt_2014_coded_output_v4_December_ncsamp.dta')

##########################################################
#### TABLE B1
##########################################################
# Keep only taxpayers whose race indicator is observed.
ncdata = ncdata[ncdata.black_ind.notna()]
nc_weight = nc_weight[['taxpayer_id', 'black_prob', 'uswgt']]

# Attach the reweighting file (uswgt) and the NC BISG probabilities.
ncdata = pd.merge(ncdata, nc_weight, on='taxpayer_id')
ncdata = pd.merge(ncdata, nc_BISG, on='taxpayer_id')

# Use the BISG probability from nc_BISG as predicted_prob_black. The
# pre-existing predicted_prob_black and black_prob columns are discarded.
ncdata = ncdata.drop(columns=['predicted_prob_black', 'black_prob']).rename(
    columns={'pprob_black': 'predicted_prob_black'}
)
ncdata['p_black_rd'] = ncdata['predicted_prob_black'].round(2)

ncdata['unitwt'] = 1

# Audited indicator excluding research audits (source codes 80 and 91).
# audit_source_code is matched as text. Judging by the delimiters, it appears
# to hold a bracketed, space-separated list of codes (TODO confirm format).
# A code counts when it is delimited by '[' or ' ' on the left and ']' or ' '
# on the right: "[80]", "[80 ...", "... 80]", and the middle-element form
# "... 80 ...".
research_code_pattern = r'[\[ ](?:80|91)[\] ]'
is_research_audit = ncdata['audit_source_code'].astype(str).str.contains(research_code_pattern, regex=True)
ncdata['aud_no_research_audits'] = ((~is_research_audit) & (ncdata['audited'] == 1)).astype('int64')

ncdata['aud_no_research_audits'] = pd.to_numeric(ncdata['aud_no_research_audits'])
ncdata['predicted_prob_black'] = pd.to_numeric(ncdata['predicted_prob_black'])
ncdata['p_black_rd'] = pd.to_numeric(ncdata['p_black_rd'])
ncdata['uswgt'] = pd.to_numeric(ncdata['uswgt'])
ncdata['unitwt'] = pd.to_numeric(ncdata['unitwt'])
ncdata['black_ind'] = pd.to_numeric(ncdata['black_ind'])

### Sub-samples: EITC claimants (eic_ind == 1) and non-claimants (eic_ind == 0)
nc_eic_df, nc_non_eic_df = (ncdata.loc[ncdata["eic_ind"].eq(flag)] for flag in (1, 0))

# Processing order: full NC sample, EITC sample, non-EITC sample.
dflist = [ncdata, nc_eic_df, nc_non_eic_df]

# Write every decomposition term for each sample to the results file.
# Values are printed explicitly: a statement consisting only of a name
# (e.g. `rho`) is a no-op outside an interactive session. The `with` block
# closes the file even if an estimate raises.
with open("/REDACTED/NC_cov_est_full.txt", "w") as out:
    print("NC Overall, NC EITC, NC NonEITC", file=out)
    for dfs in dflist:
        df = dfs.copy()
        ## recalibrate prob_black: regress black_ind on the BISG probability
        B_b_model = sm.WLS.from_formula("black_ind ~ predicted_prob_black", data = df, weights = df.unitwt).fit(cov_type = "HC1")
        B_b_model_weight = sm.WLS.from_formula("black_ind ~ predicted_prob_black", weights = df.uswgt, data = df).fit(cov_type = "HC1")
        #### rho
        rho = B_b_model.params['predicted_prob_black']
        rho_weight = B_b_model_weight.params['predicted_prob_black']
        # bhat: fitted (recalibrated) probability; epsilonhat: its residual
        df['bhat'] = B_b_model.predict(df['predicted_prob_black'])
        df['epsilonhat'] = B_b_model.resid
        df['bhat_weight'] = B_b_model_weight.predict(df['predicted_prob_black'])
        df['epsilonhat_weight'] = B_b_model_weight.resid
        lin_disp = regEstimate(df, pbvarb = 'bhat', outcome = 'aud_no_research_audits', wvarb='unitwt')
        lin_disp_weight = regEstimate(df, pbvarb = 'bhat_weight', outcome = 'aud_no_research_audits', wvarb='uswgt')
        df['etahat'] = lin_disp[2].resid
        df['etahat_weight'] = lin_disp_weight[2].resid
        print("done", file=out)
        #### cov(b, epsilon)
        df['b_eps_product'] = df['epsilonhat']*df['predicted_prob_black']
        df['b_eps_product_weight'] = df['epsilonhat_weight']*df['predicted_prob_black']
        cov_b_eps = w_avg(df, 'b_eps_product', 'unitwt') - w_avg(df, 'epsilonhat', 'unitwt')*w_avg(df, 'predicted_prob_black', 'unitwt')
        cov_b_eps_weight = w_avg(df, 'b_eps_product_weight', 'uswgt') - w_avg(df, 'epsilonhat_weight', 'uswgt')*w_avg(df, 'predicted_prob_black', 'uswgt')
        #### Covariance term Cov(E(eta|bhat), E(epsilon|bhat)), conditioning on bhat rounded to 2 d.p.
        df["bhat_bucket"] = df.bhat.round(2)
        mu_epsilon_dict = df.groupby('bhat_bucket').epsilonhat.mean().to_dict()
        mu_eta_dict = df.groupby('bhat_bucket').etahat.mean().to_dict()
        df['mu_epsilon'] = df.bhat_bucket.map(mu_epsilon_dict).astype(float)
        df['mu_eta'] = df.bhat_bucket.map(mu_eta_dict).astype(float)
        df['eta_epsilon_product'] = df['mu_epsilon'] * df['mu_eta']
        overall_cov_term = w_avg(df, 'eta_epsilon_product', 'unitwt') - w_avg(df, 'mu_eta', 'unitwt')*w_avg(df, 'mu_epsilon', 'unitwt')
        df["bhat_bucket_weight"] = df.bhat_weight.round(2)
        mu_epsilon_dict_weight = df.groupby('bhat_bucket_weight').apply(w_avg, 'epsilonhat_weight', 'uswgt').to_dict()
        mu_eta_dict_weight = df.groupby('bhat_bucket_weight').apply(w_avg, 'etahat_weight', 'uswgt').to_dict()
        df['mu_epsilon_weight'] = df.bhat_bucket_weight.map(mu_epsilon_dict_weight).astype(float)
        df['mu_eta_weight'] = df.bhat_bucket_weight.map(mu_eta_dict_weight).astype(float)
        df['eta_epsilon_product_weight'] = df['mu_epsilon_weight'] * df['mu_eta_weight']
        overall_cov_term_weight = w_avg(df, 'eta_epsilon_product_weight', 'uswgt') - w_avg(df, 'mu_eta_weight', 'uswgt')*w_avg(df, 'mu_epsilon_weight', 'uswgt')
        ### E[cov(Y,b|B)]
        cov_Y_b_B = get_cond_cov(df, 'predicted_prob_black', 'black_ind', 'aud_no_research_audits', 'unitwt')
        cov_Y_b_B_weight = get_cond_cov(df, 'predicted_prob_black', 'black_ind', 'aud_no_research_audits', 'uswgt')
        ### E[cov(Y,b*|B)]
        cov_Y_bhat_B = get_cond_cov(df, 'bhat', 'black_ind', 'aud_no_research_audits', 'unitwt')
        cov_Y_bhat_B_weight = get_cond_cov(df, 'bhat_weight', 'black_ind', 'aud_no_research_audits', 'uswgt')
        ### E[cov(Y,B|b)]
        cov_Y_B_b = get_cond_cov(df, 'black_ind', 'p_black_rd', 'aud_no_research_audits', 'unitwt')
        cov_Y_B_b_weight = get_cond_cov(df, 'black_ind', 'p_black_rd', 'aud_no_research_audits', 'uswgt')
        ### E[cov(Y,B|b*)]
        cov_Y_B_bhat = get_cond_cov(df, 'black_ind', 'bhat_bucket', 'aud_no_research_audits', 'unitwt')
        cov_Y_B_bhat_weight = get_cond_cov(df, 'black_ind', 'bhat_bucket_weight', 'aud_no_research_audits', 'uswgt')
        ### Truth: difference in audit rates by black_ind
        D = w_avg(df[df['black_ind'] == 1], 'aud_no_research_audits', 'unitwt') - w_avg(df[df['black_ind'] == 0], 'aud_no_research_audits', 'unitwt')
        D_weight = w_avg(df[df['black_ind'] == 1], 'aud_no_research_audits', 'uswgt') - w_avg(df[df['black_ind'] == 0], 'aud_no_research_audits', 'uswgt')
        ### D*_l (linear estimator)
        D_l = lin_disp[2].params['bhat']
        D_l_weight = lin_disp_weight[2].params['bhat_weight']
        ### D*_p (probabilistic estimator)
        D_p = chenEstimate(df, pbvarb = 'bhat', outcome = 'aud_no_research_audits', wvarb= 'unitwt')
        D_p_weight = chenEstimate(df, pbvarb = 'bhat_weight', outcome = 'aud_no_research_audits', wvarb= 'uswgt')

        print("Unweighted", file=out)
        for name, value in [
            ("rho", rho),
            ("cov_b_eps", cov_b_eps),
            ("cov_Y_b_B", cov_Y_b_B),
            ("cov_Y_B_b", cov_Y_B_b),
            ("cov_Y_bhat_B", cov_Y_bhat_B),
            ("cov_Y_B_bhat", cov_Y_B_bhat),
            ("overall_cov_term", overall_cov_term),
            ("D_l", D_l),
            ("D", D),
            ("D_p", D_p),
        ]:
            print(f"{name}: {value}", file=out)
        print("Weighted", file=out)
        for name, value in [
            ("rho_weight", rho_weight),
            ("cov_b_eps_weight", cov_b_eps_weight),
            ("cov_Y_b_B_weight", cov_Y_b_B_weight),
            ("cov_Y_B_b_weight", cov_Y_B_b_weight),
            ("cov_Y_bhat_B_weight", cov_Y_bhat_B_weight),
            ("cov_Y_B_bhat_weight", cov_Y_B_bhat_weight),
            ("overall_cov_term_weight", overall_cov_term_weight),
            ("D_l_weight", D_l_weight),
            ("D_weight", D_weight),
            ("D_p_weight", D_p_weight),
        ]:
            print(f"{name}: {value}", file=out)
        print("\n\n\n", file=out)